"""
Phase 8: Reviewer #2 Response — Automation Paper
=================================================
Addresses all reviewer issues:

Part 1 — Variable Definitions Fix (Issue 3.1)
  - Add capital_output_ratio (K/Y = rnna/rgdpo) from PWT as true capital intensity
  - Rerun baseline regressions with all 4 properly labeled proxies:
    I/Y (capital_intensity = gross_investment_gdp), K/Y (rnna/rgdpo), K/L (capital_per_worker), GDP per capita (labor_productivity)

Part 2 — Income Tercile Reanalysis (Issue 3.3)
  - Rerun tercile regressions with 4 clean proxies
  - Identify which proxy drives which tercile pattern

Part 3 — Mediation/Suppression Taxonomy Fix (Issue 2.1)
  - Recompute mediation with properly labeled variables
  - Produce clean summary distinguishing mediation (0–100% attenuation) from suppression (<0% or >100%)

Part 4 — Negative R² Explanation (Issue 3.2)
  - Produce table showing which models have negative R²
  - Add interpretation note

Part 5 — Variable Definition Appendix Table (Issue 3.1)
  - Generate clean table mapping variable names to sources and transformations
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# Silence every warning for cleaner console output.
# NOTE(review): this also hides genuine pandas/data warnings (e.g. dtype or
# merge-key warnings); consider narrowing to specific categories.
warnings.filterwarnings('ignore')

# PROJECT_DIR is two directory levels above this file; ROOT_DIR is its parent.
# The PanelGLS estimator is imported from a sibling "multilateral" project,
# so its src/ directory must be on sys.path BEFORE the import below.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Processed input panel directory, and output directory for the markdown
# tables (created at import time if missing).
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ISO3 codes of the 38 OECD member countries used for sample splits.
# NOTE(review): membership is applied to every year of the panel (no
# accession dates), so recent joiners count as OECD throughout — confirm
# this is intended.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
}

# Controls for the demographics -> automation-proxy regressions (Parts 1, 2, 4).
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']
# Controls for the current-account (ca_gdp) regressions in Part 3; same as
# CONTROLS plus log_rel_opw.
CONTROLS_FULL = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']


def stars(p):
    """Return the conventional significance marker for p-value *p*.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format a coefficient and its standard error as two table cells.

    Returns ``(coefficient_cell, se_cell)``: the coefficient to 4 decimals
    followed by its significance stars, and the standard error to 4
    decimals wrapped in parentheses.
    """
    coefficient_cell = "{:.4f}{}".format(val, stars(p))
    se_cell = "({:.4f})".format(se)
    return coefficient_cell, se_cell


def run_gls(df, y_var, x_vars, label):
    """Fit a PanelGLS model of *y_var* on *x_vars* and summarise it.

    Rows with any missing value in the outcome, regressors, ``iso3`` or
    ``year`` are dropped first. Fewer than 50 remaining rows, or an
    exception from the estimator, prints a message and returns None.

    Returns:
        Dict with ``model`` (= *label*), ``n_obs``, ``n_countries``,
        ``r_squared``, ``rho`` and, for each regressor ``name``, the keys
        ``{name}_coef``, ``{name}_se`` and ``{name}_p``; or None.
    """
    needed = [y_var] + x_vars + ['iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None

    model = PanelGLS()
    outcome = sample[y_var].values
    design = sample[x_vars].values
    try:
        model.fit(outcome, design, sample['iso3'].values, sample['year'].values)
    except Exception as err:
        print(f"  {label}: GLS failed ({err}), skipping")
        return None

    summary = dict(
        model=label,
        n_obs=model.n_obs,
        n_countries=model.n_countries,
        r_squared=model.r_squared,
        rho=model.rho,
    )
    print(f"\n  {label} (N={model.n_obs}, R²={model.r_squared:.4f})")
    for j, name in enumerate(x_vars):
        pval = model.pvalues[j]
        coef = model.beta[j]
        err_j = model.se[j]
        print(f"    {name:30s} {coef:8.4f} ({err_j:.4f}) {stars(pval)}")
        summary[f'{name}_coef'] = coef
        summary[f'{name}_se'] = err_j
        summary[f'{name}_p'] = pval

    return summary


def write_table(results, filename, title, notes=None, out_dir=None):
    """Write regression results as a markdown table.

    Each result dict (as returned by ``run_gls``) becomes one column. Rows
    are the union of regressors across models (keys ending in ``_coef``),
    in first-seen order, each followed by a standard-error row. A divider
    and N / R² / country-count rows come next, then the estimator footnotes.

    Args:
        results: list of result dicts; nothing is written if it is empty.
        filename: output file name, created inside *out_dir*.
        title: markdown H1 title.
        notes: optional markdown text appended after the footnotes.
        out_dir: directory to write into; defaults to ``OUT_TABLES``.
    """
    if not results:
        return

    suffix = '_coef'
    # Slice off only the trailing suffix. str.replace('_coef', '') would also
    # strip an inner occurrence (e.g. 'x_coef_lag_coef' -> 'x_lag'), and the
    # later f'{var}_coef' lookups would then miss and leave blank cells.
    all_vars = list(dict.fromkeys(
        k[:-len(suffix)] for r in results for k in r if k.endswith(suffix)
    ))

    model_labels = [r['model'] for r in results]
    col_sep = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(model_labels) + " |")
    lines.append(col_sep)

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Visual divider before the fit statistics (GFM renders it as a data row
    # of dashes, not as a second header separator).
    lines.append(col_sep)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    lines.append("\n*Panel GLS with Prais-Winsten AR(1) correction, country and year FE.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")
    if notes:
        lines.append(f"\n{notes}")

    directory = OUT_TABLES if out_dir is None else Path(out_dir)
    path = directory / filename
    # Explicit UTF-8: titles contain characters such as '→' that the platform
    # default encoding (e.g. cp1252 on Windows) cannot encode.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════
# DATA PREPARATION
# ══════════════════════════════════════════════════════════════════════

def prepare_data(panel_path=None, pwt_path=None):
    """Load the automation panel and add the capital-output ratio (K/Y) from PWT.

    Reads the processed panel and left-merges ``rgdpo`` from the Penn World
    Table file on (iso3, year). It then sets
    ``capital_output_ratio = rnna / rgdpo`` wherever both values are positive.
    If the PWT file, its ``rgdpo`` column, or the panel's ``rnna`` column is
    missing, ``capital_output_ratio`` is all-NaN and a warning is printed.

    Args:
        panel_path: panel CSV; defaults to ``DATA / "automation_panel.csv"``.
        pwt_path: PWT file. A ``.dta`` file is read with ``read_stata``;
            anything else is read with ``read_csv``. Defaults to
            ``data/raw/pwt1001.csv`` under the project.

    Returns:
        The panel DataFrame with ``capital_output_ratio`` (and ``rgdpo``
        when PWT provides it) added.
    """
    if panel_path is None:
        panel_path = DATA / "automation_panel.csv"
    df = pd.read_csv(panel_path)
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Add rgdpo from PWT to compute K/Y
    pwt_path = (PROJECT_DIR / "data" / "raw" / "pwt1001.csv"
                if pwt_path is None else Path(pwt_path))
    if not pwt_path.exists():
        print("  WARNING: PWT cache not found")
        df['capital_output_ratio'] = np.nan
        return df

    # Choose the reader from the extension. Trial-parsing a CSV as Stata
    # first, under a bare `except Exception`, also masked genuine read errors.
    if pwt_path.suffix.lower() == '.dta':
        pwt = pd.read_stata(pwt_path)
    else:
        pwt = pd.read_csv(pwt_path)

    pwt = pwt.rename(columns={'countrycode': 'iso3'})
    pwt['year'] = pd.to_numeric(pwt['year'], errors='coerce')

    if 'rgdpo' not in pwt.columns:
        print("  WARNING: rgdpo not found in PWT")
        df['capital_output_ratio'] = np.nan
        return df

    pwt['rgdpo'] = pd.to_numeric(pwt['rgdpo'], errors='coerce')
    # Keep one row per country-year so the left merge cannot duplicate panel rows.
    pwt_merge = (pwt[['iso3', 'year', 'rgdpo']]
                 .dropna()
                 .drop_duplicates(subset=['iso3', 'year']))
    # Any rgdpo already in the panel would otherwise become rgdpo_x/rgdpo_y
    # and break the df['rgdpo'] lookups below; PWT is the single source.
    df = df.drop(columns='rgdpo', errors='ignore')
    df = df.merge(pwt_merge, on=['iso3', 'year'], how='left')
    print(f"  Merged rgdpo: {df['rgdpo'].notna().sum()} obs")

    df['capital_output_ratio'] = np.nan
    if 'rnna' not in df.columns:
        print("  WARNING: rnna not found in panel; capital_output_ratio left missing")
        return df

    # Compute K/Y = rnna / rgdpo. NaN compares False, so the positivity
    # checks also exclude missing values.
    # NOTE(review): in PWT, rnna is at constant national prices and rgdpo at
    # chained PPPs; PWT pairs rnna with rgdpna for within-country analysis.
    # Confirm the intended denominator.
    mask = (df['rnna'] > 0) & (df['rgdpo'] > 0)
    df.loc[mask, 'capital_output_ratio'] = df.loc[mask, 'rnna'] / df.loc[mask, 'rgdpo']
    n_ky = df['capital_output_ratio'].notna().sum()
    print(f"  Computed capital_output_ratio (K/Y = rnna/rgdpo): {n_ky} obs")
    print(f"    Mean: {df['capital_output_ratio'].mean():.3f}, "
          f"Median: {df['capital_output_ratio'].median():.3f}")

    return df


# ══════════════════════════════════════════════════════════════════════
# PART 1: BASELINE WITH CORRECTED VARIABLE LABELS
# ══════════════════════════════════════════════════════════════════════

def part1_corrected_baselines(df):
    """Run baseline Z → proxy regressions with all 4 properly labeled proxies.

    Each automation proxy with at least 100 non-missing observations is
    regressed on Z_1..Z_3 plus CONTROLS for the full sample. The same
    regressions are then repeated separately for OECD and non-OECD
    countries. Writes ``phase8_corrected_baselines.md`` and
    ``phase8_baselines_oecd_split.md``.

    Args:
        df: panel DataFrame with ``iso3``, ``year``, the Z terms, CONTROLS and
            the proxy columns (absent or sparse proxies are skipped).

    Returns:
        List of full-sample result dicts from ``run_gls``.
    """
    print("\n" + "=" * 70)
    print("PART 1: CORRECTED BASELINE REGRESSIONS")
    print("=" * 70)

    # Variable mapping: column -> (display label, short label for split columns)
    proxies = {
        'capital_intensity': ('I/Y (Invest/GDP)', 'I/Y'),
        'capital_output_ratio': ('K/Y (Capital/Output)', 'K/Y'),
        'capital_per_worker': ('K/L (Capital/Worker)', 'K/L'),
        'labor_productivity': ('GDP per capita', 'GDP/cap'),
    }

    def has_data(col):
        # A proxy is usable when present with at least 100 non-missing values.
        return col in df.columns and df[col].notna().sum() >= 100

    results = []
    z_vars = ['Z_1', 'Z_2', 'Z_3']

    for col, (label, _short) in proxies.items():
        if not has_data(col):
            print(f"  Skipping {col}: insufficient data")
            continue

        print(f"\n--- {label} ({col}) ---")
        r = run_gls(df, col, z_vars + CONTROLS, label)
        if r:
            results.append(r)

    write_table(results, "phase8_corrected_baselines.md",
                "Corrected Baseline: Demographics → Automation Proxies",
                notes=("*I/Y = gross_investment_gdp (investment effort). "
                       "K/Y = rnna/rgdpo from PWT (capital-output ratio). "
                       "K/L = rnna/(emp×1M) from PWT (capital per worker). "
                       "GDP/capita = gdp_pc_ppp from WDI.*"))

    # OECD vs non-OECD split; build the two subsamples once, not per proxy.
    print("\n--- OECD vs non-OECD split ---")
    is_oecd = df['iso3'].isin(OECD)
    subsamples = [(df[is_oecd], 'OECD'), (df[~is_oecd], 'non-OECD')]
    results_split = []
    for col, (_label, short) in proxies.items():
        if not has_data(col):
            continue
        for subsample, sub_label in subsamples:
            # Short labels keep column headers readable; the previous
            # label[:8] cut names mid-word ("K/Y (Cap", "GDP per ").
            r = run_gls(subsample, col, z_vars + CONTROLS,
                        f"{short}: {sub_label}")
            if r:
                results_split.append(r)

    write_table(results_split, "phase8_baselines_oecd_split.md",
                "OECD vs non-OECD: Demographics → Automation Proxies")

    return results


# ══════════════════════════════════════════════════════════════════════
# PART 2: INCOME TERCILE REANALYSIS
# ══════════════════════════════════════════════════════════════════════

def part2_income_terciles(df):
    """Rerun income tercile analysis with all 4 clean proxies.

    Splits the pooled panel into GDP-per-capita (PPP) terciles and stores
    them in a new ``income_tercile`` column on *df*; the caller's frame is
    modified. For each proxy it estimates the Z → proxy regression on the
    full sample and on each tercile. Writes ``phase8_income_terciles.md``
    and prints a Z_1 summary grid.

    NOTE(review): terciles are cut on pooled country-years, so a country can
    change tercile over time.

    Returns:
        Dict mapping proxy column -> list of result dicts (Full first, then
        whichever terciles could be estimated).
    """
    print("\n" + "=" * 70)
    print("PART 2: INCOME TERCILE REANALYSIS (4 PROXIES)")
    print("=" * 70)

    terciles = ['Low', 'Mid', 'High']
    specs = ['Full'] + terciles

    # Compute terciles
    valid = df['gdp_pc_ppp'].notna()
    df.loc[valid, 'income_tercile'] = pd.qcut(
        df.loc[valid, 'gdp_pc_ppp'], 3, labels=terciles
    )
    for t in terciles:
        in_t = df['income_tercile'] == t
        n = in_t.sum()
        nc = df.loc[in_t, 'iso3'].nunique()
        med_gdppc = df.loc[in_t, 'gdp_pc_ppp'].median()
        print(f"  {t}: {n} obs, {nc} countries, median GDP/cap ${med_gdppc:,.0f}")

    proxies = {
        'capital_intensity': 'I/Y',
        'capital_output_ratio': 'K/Y',
        'capital_per_worker': 'K/L',
        'labor_productivity': 'GDP/cap',
    }

    z_vars = ['Z_1', 'Z_2', 'Z_3']

    # Run per proxy: Full sample + 3 terciles
    all_results = {}
    for col, short_label in proxies.items():
        if col not in df.columns or df[col].notna().sum() < 100:
            continue

        print(f"\n--- {short_label} ({col}) ---")
        proxy_results = []
        for spec in specs:
            sub = df if spec == 'Full' else df[df['income_tercile'] == spec]
            r = run_gls(sub, col, z_vars + CONTROLS, f"{short_label}: {spec}")
            if r:
                proxy_results.append(r)

        all_results[col] = proxy_results

    # Write combined table (one column per proxy-tercile combo)
    combined = [r for results in all_results.values() for r in results]

    write_table(combined, "phase8_income_terciles.md",
                "Income Tercile Heterogeneity: 4 Automation Proxies",
                notes=("*I/Y = gross_investment_gdp. K/Y = rnna/rgdpo. "
                       "K/L = capital_per_worker. GDP/cap = gdp_pc_ppp. "
                       "Terciles by GDP per capita (PPP).*"))

    # Summary: Z_1 coefficient by proxy × tercile. Results are looked up by
    # model label, not by position: a skipped tercile would otherwise shift
    # every later value under the wrong column header.
    print("\n\n  === Z_1 Summary by Proxy × Tercile ===")
    print(f"  {'Proxy':12s} {'Full':>12s} {'Low':>12s} {'Mid':>12s} {'High':>12s}")
    for col, results in all_results.items():
        short_label = proxies[col]
        by_model = {r['model']: r for r in results}
        row = f"  {short_label:12s}"
        for spec in specs:
            r = by_model.get(f"{short_label}: {spec}")
            if r is None:
                row += f" {'n/a':>12s}"
                continue
            z1 = r.get('Z_1_coef', np.nan)
            z1p = r.get('Z_1_p', 1.0)
            row += f" {z1:8.2f}{stars(z1p):3s}"
        print(row)

    return all_results


# ══════════════════════════════════════════════════════════════════════
# PART 3: MEDIATION/SUPPRESSION WITH CORRECTED LABELS
# ══════════════════════════════════════════════════════════════════════

def part3_mediation_corrected(df):
    """Recompute mediation with proper variable labels and terminology.

    For the Full and OECD samples, estimates Z → CA (``ca_gdp``) and, for
    each capital proxy, Z + proxy → CA. The attenuation of the Z_1
    coefficient is ``(b_base - b_controlled) / b_base * 100``. It is
    classified as mediation when it lies in 0–100% and as suppression
    otherwise.

    Each attenuation uses a baseline re-estimated on the mediator model's
    own sample. Otherwise rows dropped for a missing mediator would change
    the sample, and the "attenuation" would partly reflect that change.
    Writes per-sample tables and ``phase8_mediation_summary.md``.
    """
    print("\n" + "=" * 70)
    print("PART 3: MEDIATION/SUPPRESSION ANALYSIS (CORRECTED)")
    print("=" * 70)

    proxies = {
        'capital_intensity': 'I/Y',
        'capital_output_ratio': 'K/Y',
        'capital_per_worker': 'K/L',
    }

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    summary_rows = []

    for sample_label, sub_df in [('Full', df), ('OECD', df[df['iso3'].isin(OECD)])]:
        print(f"\n{'='*50}")
        print(f"  {sample_label} sample (N={len(sub_df)}, "
              f"{sub_df['iso3'].nunique()} countries)")
        print(f"{'='*50}")

        # Reference baseline: Z → CA on every row with CA, Z and controls.
        r_base = run_gls(sub_df, 'ca_gdp', z_vars + CONTROLS_FULL,
                         f'{sample_label}: Z→CA')
        if r_base is None:
            continue

        results_sample = [r_base]

        for col, short in proxies.items():
            if col not in sub_df.columns or sub_df[col].notna().sum() < 50:
                continue

            # Restrict to rows with the mediator observed. After run_gls's
            # dropna, the baseline and mediated models then use identical
            # observations, so the change in Z_1 is attributable to the mediator.
            matched = sub_df[sub_df[col].notna()]
            r_base_m = run_gls(matched, 'ca_gdp', z_vars + CONTROLS_FULL,
                               f'{sample_label}: Z→CA [{short} sample]')
            r_med = run_gls(matched, 'ca_gdp',
                            z_vars + [col] + CONTROLS_FULL,
                            f'{sample_label}: Z+{short}→CA')
            if r_base_m is None or r_med is None:
                continue
            results_sample.extend([r_base_m, r_med])

            b_base = r_base_m.get('Z_1_coef', 0)
            b_med = r_med.get('Z_1_coef', 0)
            # Attenuation is undefined for a (near-)zero or NaN baseline;
            # written as `not >` so NaN is also skipped.
            if not abs(b_base) > 1e-10:
                continue

            atten = (b_base - b_med) / b_base * 100
            # Boundaries match the table notes: 0–100% is mediation.
            mechanism = 'Mediation' if 0 <= atten <= 100 else 'Suppression'
            print(f"\n  {short} {mechanism}: Z_1 {b_base:.2f} → {b_med:.2f} "
                  f"(atten: {atten:.1f}%)")

            summary_rows.append({
                'Mediator': f"{short} ({col})",
                'Sample': sample_label,
                'Z_1 baseline': f"{b_base:.4f}{stars(r_base_m['Z_1_p'])}",
                'Z_1 controlled': f"{b_med:.4f}{stars(r_med['Z_1_p'])}",
                'Attenuation': f"{atten:.1f}%",
                'Pattern': mechanism,
            })

        if len(results_sample) > 1:
            write_table(results_sample,
                        f"phase8_mediation_{sample_label.lower()}.md",
                        f"{sample_label}: CA Mediation with Corrected Proxies")

    # Summary table
    lines = ["# Mediation/Suppression Summary (Corrected Variable Definitions)\n"]
    lines.append("| Mediator | Sample | Z₁ baseline | Z₁ controlled | Attenuation | Pattern |")
    lines.append("|:---|:---|---:|---:|---:|:---|")
    for row in summary_rows:
        lines.append(f"| {row['Mediator']} | {row['Sample']} | {row['Z_1 baseline']} | "
                     f"{row['Z_1 controlled']} | {row['Attenuation']} | {row['Pattern']} |")

    lines.append("\n*Mediation: 0–100% attenuation (mediator absorbs part of Z effect).*")
    lines.append("*Suppression: <0% or >100% attenuation (mediator reveals opposing channels).*")
    lines.append("*Baseline re-estimated on each mediator's estimation sample (matched N).*")
    lines.append("*Baron & Kenny (1986) framework with PanelGLS.*")

    path = OUT_TABLES / "phase8_mediation_summary.md"
    # UTF-8 explicitly: 'Z₁' and '–' are not encodable in e.g. cp1252.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════
# PART 4: NEGATIVE R² EXPLANATION
# ══════════════════════════════════════════════════════════════════════

def part4_negative_r2(df):
    """Document which models have negative R² and why.

    Re-estimates each proxy (plus ``gvc_proxy``) on Z_1..Z_3 and CONTROLS
    for the Full, OECD and non-OECD samples. It tabulates N, R² and the
    AR(1) rho, flags negative R² with '⚠', and writes
    ``phase8_negative_r2.md`` with an interpretation note.
    """
    print("\n" + "=" * 70)
    print("PART 4: NEGATIVE R² DOCUMENTATION")
    print("=" * 70)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    proxies = ['capital_intensity', 'capital_output_ratio', 'capital_per_worker',
               'labor_productivity', 'gvc_proxy']

    # Build the subsamples once rather than once per dependent variable.
    is_oecd = df['iso3'].isin(OECD)
    samples = [('Full', df), ('OECD', df[is_oecd]), ('non-OECD', df[~is_oecd])]

    rows = []
    for col in proxies:
        if col not in df.columns or df[col].notna().sum() < 100:
            continue

        for sample_label, sub_df in samples:
            r = run_gls(sub_df, col, z_vars + CONTROLS, f"{col}: {sample_label}")
            if r:
                rows.append({
                    'DV': col,
                    'Sample': sample_label,
                    'N': r['n_obs'],
                    'R²': r['r_squared'],
                    'rho': r['rho'],
                    'Flag': '⚠' if r['r_squared'] < 0 else '',
                })

    lines = ["# Model Fit Summary (R² and AR(1) Correction)\n"]
    lines.append("**Reviewer Issue 3.2**: Why do some models report negative R²?\n")
    lines.append("| Dependent Variable | Sample | N | R² | AR(1) ρ | Flag |")
    lines.append("|:---|:---|---:|---:|---:|:---|")
    for row in rows:
        lines.append(f"| {row['DV']} | {row['Sample']} | {row['N']} | "
                     f"{row['R²']:.4f} | {row['rho']:.3f} | {row['Flag']} |")

    lines.append("\n## Interpretation")
    lines.append("")
    lines.append("Negative R² values arise under GLS estimation when the model's fit, after")
    lines.append("Prais-Winsten AR(1) correction and within-group demeaning, is worse than a")
    lines.append("simple within-group mean. This occurs when:")
    # Country FE absorb cross-country differences; what remains to explain is
    # within-country (over-time) variation, not cross-sectional variation.
    lines.append("1. The regressors explain little of the within-country variation that")
    lines.append("   remains after absorbing country and year fixed effects;")
    lines.append("2. The AR(1) transformation alters the effective dependent variable enough")
    lines.append("   that a model fitting well in levels fits poorly in quasi-differences.")
    lines.append("")
    lines.append("Negative R² does NOT indicate a coding error. The coefficient estimates and")
    lines.append("standard errors remain valid — only the goodness-of-fit measure is")
    lines.append("uninterpretable as variance explained. We report these values transparently")
    lines.append("rather than censoring at zero.")

    path = OUT_TABLES / "phase8_negative_r2.md"
    # UTF-8 explicitly: '⚠' and 'ρ' are not encodable in e.g. cp1252.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════
# PART 5: VARIABLE DEFINITION TABLE
# ══════════════════════════════════════════════════════════════════════

def part5_variable_definitions(df):
    """Generate variable definition appendix table.

    Writes ``phase8_variable_definitions.md``, which contains:
    - a static table mapping paper labels to column names, sources and
      constructions;
    - a coverage table (non-missing obs, countries, year span) for each
      listed column present in *df*;
    - a short note on how I/Y, K/Y and K/L differ.
    """
    print("\n" + "=" * 70)
    print("PART 5: VARIABLE DEFINITION TABLE")
    print("=" * 70)

    lines = ["# Variable Definitions\n"]
    lines.append("**Reviewer Issue 3.1**: Clarify variable construction.\n")

    lines.append("## Automation Proxies\n")
    lines.append("| Label in Paper | Column Name | Source | Construction | Units |")
    lines.append("|:---|:---|:---|:---|:---|")
    lines.append("| Investment/GDP (I/Y) | capital_intensity | WDI | "
                 "Gross capital formation / GDP | % of GDP |")
    lines.append("| Capital-output ratio (K/Y) | capital_output_ratio | PWT 10.01 | "
                 "rnna / rgdpo (both mil. 2017 US$; rnna at constant national prices, "
                 "rgdpo at chained PPPs) | Ratio |")
    lines.append("| Capital per worker (K/L) | capital_per_worker | PWT 10.01 | "
                 "rnna / (emp × 1M) | Millions per worker |")
    lines.append("| GDP per capita | labor_productivity | WDI | "
                 "GDP per capita, PPP (current intl $) | USD PPP |")
    lines.append("| Labor share | labsh | PWT 10.01 | "
                 "Share of labor compensation in GDP | 0–1 |")
    lines.append("| Capital share | automation_proxy | PWT 10.01 | "
                 "1 − labsh | 0–1 |")
    lines.append("| Trade openness | gvc_proxy | WDI | "
                 "(Exports + Imports) / GDP | % of GDP |")

    # Report coverage for every variable listed above that exists in df
    # (automation_proxy was previously omitted).
    lines.append("\n## Coverage\n")
    lines.append("| Variable | Non-missing obs | Countries | Years |")
    lines.append("|:---|---:|---:|:---|")
    for col in ['capital_intensity', 'capital_output_ratio', 'capital_per_worker',
                'labor_productivity', 'labsh', 'automation_proxy', 'gvc_proxy']:
        if col in df.columns:
            sub = df[df[col].notna()]
            if len(sub) > 0:
                lines.append(f"| {col} | {len(sub)} | {sub['iso3'].nunique()} | "
                             f"{int(sub['year'].min())}–{int(sub['year'].max())} |")

    lines.append("\n## Key Distinction\n")
    lines.append("**Investment/GDP (I/Y)** measures investment *effort* — the flow of new "
                 "capital formation as a share of current output. It reflects how much of "
                 "output is being directed toward capital accumulation in a given year.")
    lines.append("")
    lines.append("**Capital-output ratio (K/Y)** measures the *stock* of capital relative "
                 "to output. It reflects the accumulated capital intensity of the production "
                 "structure. K/Y moves slowly: it rises as investment accumulates and falls "
                 "as capital depreciates or output grows faster than the capital stock.")
    lines.append("")
    lines.append("**Capital per worker (K/L)** measures the capital available to each "
                 "worker. It can diverge from K/Y when labor force growth differs from "
                 "output growth (e.g., immigration raises L faster than K).")
    lines.append("")
    lines.append("In earlier drafts, 'capital intensity' referred to I/Y. We now use "
                 "explicit labels throughout to avoid ambiguity.")

    path = OUT_TABLES / "phase8_variable_definitions.md"
    # UTF-8 explicitly: '×' is fine in cp1252 but '−' (U+2212) is not.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════

def main():
    """Run every Phase 8 step in order on one freshly prepared panel.

    Loads the data once via ``prepare_data``. It then passes the same
    DataFrame to Parts 1–5 in sequence: corrected baselines, income
    terciles, mediation/suppression, negative-R² documentation, and
    variable definitions.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 8: REVIEWER #2 RESPONSE — AUTOMATION PAPER")
    print(rule)

    panel = prepare_data()

    pipeline = (
        part1_corrected_baselines,    # Part 1: corrected baselines
        part2_income_terciles,        # Part 2: income tercile reanalysis
        part3_mediation_corrected,    # Part 3: mediation/suppression
        part4_negative_r2,            # Part 4: negative R² documentation
        part5_variable_definitions,   # Part 5: variable definitions
    )
    for step in pipeline:
        step(panel)

    print("\n" + rule)
    print("PHASE 8 COMPLETE")
    print(rule)


if __name__ == '__main__':
    main()
